from src.utils.load_models import load_client_models
from src.utils.load_models import load_models, load_model3, load_model_qktD

import pytorch_lightning as pl
import torch
import logging
import torch.nn.functional as F
from hydra.utils import instantiate
from torch import nn
from torchmetrics import MaxMetric
from torchmetrics import Accuracy
from torchmetrics import ConfusionMatrix
import wandb
import numpy as np
from typing import Union, List
import copy
import os

# Module-level logger; used below to report validation/test accuracy summaries.
log = logging.getLogger(__name__)

class Distilltion(pl.LightningModule):
    """Distil knowledge from one or more frozen teacher clients into a learner client.

    The learner is trained on a mix of cross-entropy against the labels and a
    temperature-scaled KL divergence towards the teacher logits. Learner and
    teacher weights come from ``load_client_models``; the teacher(s) are frozen
    (``requires_grad_(False)``) and put in eval mode.

    Args:
        cfg: Hydra/OmegaConf config. Keys read by this class include
            ``num_classes``, ``test_description``, ``num_clients``,
            ``learner_client``, ``teacher_client``, ``clients``, ``model``,
            ``KL_temperature``, ``no_alpha``, ``not_multiply_T``,
            ``random_alpha``, ``KL_loss_strength``, ``multiple_teachers``,
            ``qkt_multi_teachers`` and ``optim``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.save_hyperparameters(cfg)
        self.num_classes = cfg.num_classes
        self.test_description = self.hparams.test_description

        # Metric/log key prefix, e.g. "client-0_from_client-1".
        self.exp_name = f"client-{self.hparams.learner_client}_from_client-{self.hparams.teacher_client}"

        # Load teacher & client weights.
        learner_model, teacher_model = load_client_models(
            num_clients=self.hparams.num_clients, learner_client=self.hparams.learner_client,
            teacher_client=self.hparams.teacher_client, clients=self.hparams.clients,
            learner_model=instantiate(self.hparams.model), teacher_model=instantiate(self.hparams.model),
            device=self.device,
        )
        self.learner_model: nn.Module = learner_model.to(self.device)
        # A plain Python list of teachers is NOT registered as submodules, so
        # Lightning will not move it; on_fit_start moves it explicitly.
        self.teacher_model: Union[nn.Module, List[nn.Module]] = self._teachers_to(teacher_model, self.device)

        # Frozen snapshot of the learner as loaded, before any distillation.
        self.org_learner_model_copy = copy.deepcopy(self.learner_model)
        self.org_learner_model_copy.requires_grad_(False)

        self.train_confusion_matrix = ConfusionMatrix(num_classes=self.num_classes)
        self.val_confusion_matrix = ConfusionMatrix(num_classes=self.num_classes)
        self.test_confusion_matrix = ConfusionMatrix(num_classes=self.num_classes)

        # Classification loss.
        self.cls_loss = nn.CrossEntropyLoss()

        # Separate metric instances for train, val and test steps.
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()

        self.per_class_test_acc = []
        self.per_class_val_acc = []

        self.private_val_acc = 0
        self.server_test_acc = 0

        # Best-so-far validation accuracy.
        self.val_acc_best = MaxMetric()

        self.T = self.hparams.KL_temperature
        self.no_alpha = self.hparams.no_alpha

        # Freeze every teacher (works for a single module or a collection).
        for t_model in self._teacher_list():
            t_model.requires_grad_(False)
            t_model.eval()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _teachers_to(teacher, device):
        """Move a teacher module, or each module of a plain list, to ``device``."""
        if isinstance(teacher, list):
            return [t_model.to(device) for t_model in teacher]
        return teacher.to(device)

    def _teacher_list(self) -> List[nn.Module]:
        """Return the teacher(s) as a list, wrapping a single module."""
        teachers = self.teacher_model
        if isinstance(teachers, (list, nn.ModuleList)):
            return list(teachers)
        return [teachers]

    def _unpack_batch(self, batch):
        """Split ``batch`` into ``(x, y)`` on this module's device with 1-D labels."""
        x, y = batch
        x = x.to(self.device)
        y = y.to(self.device)
        if y.dim() > 1:
            y = y.squeeze(1)  # Squeeze the labels to ensure they are 1D (for medMNIST dataset)
        return x, y

    @staticmethod
    def _per_class_acc(conf_metric) -> np.ndarray:
        """Return per-class accuracy (recall) from a confusion-matrix metric.

        Divides each diagonal entry by the total of its row. A class with no
        entries in its row yields NaN (0 / 0).
        """
        conf = conf_metric.compute()
        per_class = torch.diagonal(conf) / conf.sum(dim=1)
        return per_class.cpu().detach().numpy()

    # ------------------------------------------------------------- lightning API

    def on_fit_start(self):
        self.learner_model = self.learner_model.to(self.device)
        self.teacher_model = self._teachers_to(self.teacher_model, self.device)
        self.org_learner_model_copy = self.org_learner_model_copy.to(self.device)

    def forward(self, x):
        return self.learner_model(x)

    def divergence(self, student_logits, teacher_logits, alpha=None):
        """Return the temperature-scaled KL(teacher || student) loss.

        Args:
            student_logits: learner logits; softmax is taken over dim 1.
            teacher_logits: teacher logits, same shape as ``student_logits``.
            alpha: optional weight applied unless ``no_alpha`` is set. ``None``
                means "no weighting"; ``0`` zeroes the term (previously a falsy
                alpha was silently ignored).

        Returns:
            Scalar tensor, multiplied by ``T**2`` unless ``not_multiply_T``.
        """
        temperature = self.hparams.KL_temperature
        divergence = F.kl_div(
            F.log_softmax(student_logits / temperature, dim=1),
            F.softmax(teacher_logits / temperature, dim=1),
            reduction='batchmean'
        )
        if not self.no_alpha and alpha is not None:
            divergence = divergence * alpha

        if not self.hparams.not_multiply_T:
            divergence = divergence * temperature * temperature

        return divergence

    def step(self, batch):
        """Return ``(cross-entropy loss, argmax predictions, targets)`` for a batch."""
        x, y = self._unpack_batch(batch)
        logits = self(x)
        loss = self.cls_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        return loss, preds, y

    def training_step(self, batch, batch_idx):
        x, y = self._unpack_batch(batch)
        logits = self(x)

        alpha = np.random.uniform(low=0.25) if self.hparams.random_alpha else self.hparams.KL_loss_strength
        self.alpha = alpha

        cls_loss = self.cls_loss(logits, y)
        if not self.no_alpha:
            cls_loss = cls_loss * (1 - alpha)

        teachers = self._teacher_list()
        if self.hparams.multiple_teachers or self.hparams.qkt_multi_teachers or len(teachers) > 1:
            # NOTE(review): alpha is not applied to the per-teacher KL terms here
            # (unlike the single-teacher path) and terms are summed, not averaged
            # — confirm this is intended.
            kl_loss = 0
            for t_model in teachers:
                with torch.no_grad():
                    teacher_logits = t_model(x)
                kl_loss = kl_loss + self.divergence(student_logits=logits, teacher_logits=teacher_logits)
        else:  # single teacher (bare module or one-element collection)
            with torch.no_grad():
                teacher_logits = teachers[0](x)
            kl_loss = self.divergence(student_logits=logits, teacher_logits=teacher_logits, alpha=alpha)

        self.log(
            f"{self.exp_name}/learner-ce_loss",
            cls_loss, on_step=True, on_epoch=False, prog_bar=True
        )
        self.log(
            f"{self.exp_name}/learner-kl_loss",
            kl_loss, on_step=True, on_epoch=False, prog_bar=True
        )

        preds = torch.argmax(logits, dim=1)
        acc = self.train_acc(preds, y)
        self.log(
            f"{self.exp_name}/train_acc",
            acc, on_step=False, on_epoch=True, prog_bar=False
        )

        return {"loss": kl_loss + cls_loss, "preds": preds, "targets": y}

    def validation_step(self, batch, batch_idx):
        loss, preds, targets = self.step(batch)

        acc = self.val_acc(preds, targets)
        # Update the confusion matrix exactly once per batch.
        self.val_confusion_matrix.update(preds, targets)
        self.log(
            f"{self.exp_name}/val_cls_loss",
            loss, on_step=False, on_epoch=True, prog_bar=False
        )
        self.log(
            f"{self.exp_name}/val_acc",
            acc, on_step=False, on_epoch=True, prog_bar=True
        )

        return {"loss": loss, "preds": preds, "targets": targets}

    def validation_epoch_end(self, outputs):
        acc = self.val_acc.compute()
        self.val_acc_best.update(acc)
        self.log(
            f"{self.exp_name}/best_val_acc",
            self.val_acc_best.compute(), on_epoch=True, prog_bar=True
        )

        self.per_class_val_acc = self._per_class_acc(self.val_confusion_matrix)
        self.private_val_acc = acc

        summary = self.logger.experiment.summary
        summary[f"{self.exp_name}/val_acc"] = acc
        summary[f"{self.exp_name}/val_per_class_acc"] = self.per_class_val_acc

        log.info(f"{self.exp_name}/val_acc: {acc}")
        log.info(f"{self.exp_name}/val_per_class_acc: {self.per_class_val_acc}")

    def test_step(self, batch, batch_idx):
        loss, preds, targets = self.step(batch)

        acc = self.test_acc(preds, targets)
        self.test_confusion_matrix.update(preds, targets)

        self.log(
            f"{self.exp_name}/test_cls_loss",
            loss, on_step=False, on_epoch=True
        )
        self.log(
            f"{self.exp_name}/test_acc",
            acc, on_step=False, on_epoch=True
        )

        return {"loss": loss, "preds": preds, "targets": targets}

    def test_epoch_end(self, outputs):
        acc = self.test_acc.compute()
        self.server_test_acc = acc
        self.per_class_test_acc = self._per_class_acc(self.test_confusion_matrix)

        prefix = f"{self.exp_name}/{self.test_description}"
        summary = self.logger.experiment.summary
        summary[f"{prefix}_per_class_test_acc"] = self.per_class_test_acc
        summary[f"{prefix}_test_acc"] = acc

        log.info(f"{prefix}_test_acc: {acc}")
        log.info(f"{prefix}_per_class_test_acc: {self.per_class_test_acc}")

    def on_epoch_end(self):
        self.train_acc.reset()
        self.test_acc.reset()
        self.val_acc.reset()
        self.train_confusion_matrix.reset()
        self.val_confusion_matrix.reset()
        self.test_confusion_matrix.reset()

    def on_test_end(self):
        self.val_acc_best.reset()

    def configure_optimizers(self):
        """Build the learner optimizer and, if configured, a cosine warmup scheduler.

        Only the learner's parameters are optimised. Returns
        ``([optimizer], [scheduler])``; the scheduler list is empty when
        ``optim.scheduler`` is falsy.
        """
        learner_optim = instantiate(config=self.hparams.optim.optim, params=self.learner_model.parameters())
        if not self.hparams.optim.scheduler:
            return [learner_optim], []

        # Optional dependency: imported only when a scheduler is requested.
        from cosine_annealing_warmup import CosineAnnealingWarmupRestarts

        sched_params = self.hparams.optim.scheduler_params
        # NOTE(review): returned without an "interval", so Lightning presumably
        # steps this once per epoch — confirm *_steps are meant as epochs.
        scheduler = CosineAnnealingWarmupRestarts(
            learner_optim,
            first_cycle_steps=sched_params.first_cycle_steps,
            max_lr=learner_optim.param_groups[0]['lr'],
            min_lr=sched_params.min_lr,
            warmup_steps=sched_params.warmup_steps,
        )
        return [learner_optim], [scheduler]
